{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f0c7069a",
   "metadata": {},
   "source": [
    "## Unit Tests and Functional Tests (Quality Assurance) Runbook\n",
    "This notebook shows how we design basic unit and functional tests for our library. Please follow this runbook when contributing new tests so that we can keep the quality of our code high."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "31e986e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "__author__ = \"Zhengxuan Wu\"\n",
    "__version__ = \"12/28/2023\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "363cee96",
   "metadata": {},
   "source": [
    "### Overview\n",
    "\n",
    "We follow a generic QA framework in which we write positive and negative test cases for each function or API. Each test should ideally cover multiple cases, and these cases should share the same set-up. In this tutorial, we cover one test case for subspace intervention with a simple MLP model and static interventions. Overall, we check the results from our API against golden labels that are created with manual interventions. For trainable interventions, we also need to check gradients."
   ]
  },
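  {
   "cell_type": "markdown",
   "id": "9c2e41b7",
   "metadata": {},
   "source": [
    "As a rough illustration of the golden-label idea, the sketch below patches an activation by hand and compares the result with the same intervention done through a forward hook. It uses plain PyTorch only: the toy `model`, `fetch_activation`, `run_from_activation`, and `swap_hook` are hypothetical stand-ins, not part of our library, and the hook merely plays the role of the API under test."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d8a03fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Hypothetical toy model (an assumption, not the library's MLP).\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(3, 3, bias=False),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(3, 2, bias=False),\n",
    ")\n",
    "\n",
    "\n",
    "def fetch_activation(x):\n",
    "    # Activation right after the first layer and its nonlinearity.\n",
    "    return model[1](model[0](x))\n",
    "\n",
    "\n",
    "def run_from_activation(act):\n",
    "    # Finish the forward pass from a (possibly patched) activation.\n",
    "    return model[2](act)\n",
    "\n",
    "\n",
    "base = torch.rand(10, 1, 3)\n",
    "source = torch.rand(10, 1, 3)\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Golden label: copy neurons 1 and 2 from source into base by hand.\n",
    "    source_act = fetch_activation(source)\n",
    "    golden_act = fetch_activation(base).clone()\n",
    "    golden_act[..., 1:3] = source_act[..., 1:3]\n",
    "    golden_out = run_from_activation(golden_act)\n",
    "\n",
    "    # Stand-in for the API under test: the same swap done inside a hook.\n",
    "    def swap_hook(module, inputs, output):\n",
    "        patched = output.clone()\n",
    "        patched[..., 1:3] = source_act[..., 1:3]\n",
    "        return patched\n",
    "\n",
    "    handle = model[1].register_forward_hook(swap_hook)\n",
    "    hooked_out = model(base)\n",
    "    handle.remove()\n",
    "\n",
    "# The hand-made golden label and the hooked run should agree.\n",
    "assert torch.allclose(golden_out, hooked_out)"
   ]
  },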
  {
   "cell_type": "markdown",
   "id": "6b54a060",
   "metadata": {},
   "source": [
    "### Set-up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "78594792",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'pyvene' is not installed.\n",
      "PASS: pyvene is not installed. Testing local dev code.\n",
      "[2024-01-12 03:20:08,967] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    }
   ],
   "source": [
    "import unittest\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "502ae372",
   "metadata": {},
   "source": [
    "### Example of a test case for subspace intervention\n",
    "\n",
    "When developing these test cases, assume that the code has bugs and design cases that try as hard as possible to trick the system. The following test cases use these tricks:\n",
    "- **Subspace scramble**: We list the subspace partition in an unconventional order, e.g., `[[1,3],[0,1]]` instead of `[[0,1],[1,3]]`. The code should be order-agnostic.\n",
    "- **Uneven subspace partition**: The partition `[[1,3],[0,1]]` splits the neurons unevenly between subspaces (two neurons versus one).\n",
    "- **Untouched subspace**: Instead of intervening on every available subspace, we leave one neuron out with a subspace partition like `[[0,1],[1,2]]`."
   ]
  },
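  {
   "cell_type": "markdown",
   "id": "e47b19c0",
   "metadata": {},
   "source": [
    "To make the partition notation concrete, here is a minimal sketch assuming each pair is a half-open `[start, end)` range over neuron indices, which is how the golden labels in the tests below slice activations. The helper `swap_subspace` is hypothetical and exists only for illustration; it is not a function from our library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b30f6a2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "partition = [[1, 3], [0, 1]]  # scrambled order, uneven sizes\n",
    "\n",
    "\n",
    "def swap_subspace(base_act, source_act, partition, subspace_index):\n",
    "    # Hypothetical helper: copy one subspace from source into base.\n",
    "    start, end = partition[subspace_index]\n",
    "    out = base_act.clone()\n",
    "    out[..., start:end] = source_act[..., start:end]\n",
    "    return out\n",
    "\n",
    "\n",
    "base_act = torch.zeros(2, 1, 3)\n",
    "source_act = torch.ones(2, 1, 3)\n",
    "\n",
    "# Subspace 0 is [1, 3): neurons 1 and 2 come from the source.\n",
    "swapped_0 = swap_subspace(base_act, source_act, partition, 0)\n",
    "assert swapped_0[0, 0].tolist() == [0.0, 1.0, 1.0]\n",
    "\n",
    "# Subspace 1 is [0, 1): only neuron 0 comes from the source.\n",
    "swapped_1 = swap_subspace(base_act, source_act, partition, 1)\n",
    "assert swapped_1[0, 0].tolist() == [1.0, 0.0, 0.0]"
   ]
  },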
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "183a1f66",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "...../u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/torch/autograd/__init__.py:200: UserWarning: Error detected in GatherBackward0. Traceback of forward call that caused the error:\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/runpy.py\", line 194, in _run_module_as_main\n",
      "    return _run_code(code, main_globals, None,\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/runpy.py\", line 87, in _run_code\n",
      "    exec(code, run_globals)\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/ipykernel_launcher.py\", line 16, in <module>\n",
      "    app.launch_new_instance()\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/traitlets/config/application.py\", line 846, in launch_instance\n",
      "    app.start()\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/ipykernel/kernelapp.py\", line 677, in start\n",
      "    self.io_loop.start()\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/tornado/platform/asyncio.py\", line 199, in start\n",
      "    self.asyncio_loop.run_forever()\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/asyncio/base_events.py\", line 570, in run_forever\n",
      "    self._run_once()\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/asyncio/base_events.py\", line 1859, in _run_once\n",
      "    handle._run()\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/asyncio/events.py\", line 81, in _run\n",
      "    self._context.run(self._callback, *self._args)\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/ipykernel/kernelbase.py\", line 457, in dispatch_queue\n",
      "    await self.process_one()\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/ipykernel/kernelbase.py\", line 446, in process_one\n",
      "    await dispatch(*args)\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/ipykernel/kernelbase.py\", line 353, in dispatch_shell\n",
      "    await result\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/ipykernel/kernelbase.py\", line 648, in execute_request\n",
      "    reply_content = await reply_content\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/ipykernel/ipkernel.py\", line 353, in do_execute\n",
      "    res = shell.run_cell(code, store_history=store_history, silent=silent)\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/ipykernel/zmqshell.py\", line 533, in run_cell\n",
      "    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/IPython/core/interactiveshell.py\", line 2901, in run_cell\n",
      "    result = self._run_cell(\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/IPython/core/interactiveshell.py\", line 2947, in _run_cell\n",
      "    return runner(coro)\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/IPython/core/async_helpers.py\", line 68, in _pseudo_sync_runner\n",
      "    coro.send(None)\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/IPython/core/interactiveshell.py\", line 3172, in run_cell_async\n",
      "    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/IPython/core/interactiveshell.py\", line 3364, in run_ast_nodes\n",
      "    if (await self.run_code(code, result,  async_=asy)):\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/IPython/core/interactiveshell.py\", line 3444, in run_code\n",
      "    exec(code_obj, self.user_global_ns, self.user_ns)\n",
      "  File \"/tmp/wuzhengx/ipykernel_2483492/1746668463.py\", line 318, in <module>\n",
      "    runner.run(suite())\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/unittest/runner.py\", line 176, in run\n",
      "    test(result)\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/unittest/suite.py\", line 84, in __call__\n",
      "    return self.run(*args, **kwds)\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/unittest/suite.py\", line 122, in run\n",
      "    test(result)\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/unittest/case.py\", line 736, in __call__\n",
      "    return self.run(*args, **kwds)\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/unittest/case.py\", line 676, in run\n",
      "    self._callTestMethod(testMethod)\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/unittest/case.py\", line 633, in _callTestMethod\n",
      "    method()\n",
      "  File \"/tmp/wuzhengx/ipykernel_2483492/1746668463.py\", line 284, in test_no_intervention_link_negative\n",
      "    _, our_out_overwrite = intervenable(\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1501, in _call_impl\n",
      "    return forward_call(*args, **kwargs)\n",
      "  File \"/juice/scr/wuzhengx/pyvene/tests/../pyvene/models/intervenable_base.py\", line 1191, in forward\n",
      "    counterfactual_outputs = self.model(**base)\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1501, in _call_impl\n",
      "    return forward_call(*args, **kwargs)\n",
      "  File \"/juice/scr/wuzhengx/pyvene/tests/../pyvene/models/mlp/modelings_mlp.py\", line 128, in forward\n",
      "    mlp_outputs = self.mlp(\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1501, in _call_impl\n",
      "    return forward_call(*args, **kwargs)\n",
      "  File \"/juice/scr/wuzhengx/pyvene/tests/../pyvene/models/mlp/modelings_mlp.py\", line 95, in forward\n",
      "    hidden_states = block(hidden_states)\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1501, in _call_impl\n",
      "    return forward_call(*args, **kwargs)\n",
      "  File \"/juice/scr/wuzhengx/pyvene/tests/../pyvene/models/mlp/modelings_mlp.py\", line 57, in forward\n",
      "    return self.dropout(self.act(self.ff1(hidden_states)))\n",
      "  File \"/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-bootleg/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1545, in _call_impl\n",
      "    hook_result = hook(self, args, kwargs, result)\n",
      "  File \"/juice/scr/wuzhengx/pyvene/tests/../pyvene/models/intervenable_base.py\", line 799, in hook_callback\n",
      "    selected_output = self._gather_intervention_output(\n",
      "  File \"/juice/scr/wuzhengx/pyvene/tests/../pyvene/models/intervenable_base.py\", line 562, in _gather_intervention_output\n",
      "    selected_output = gather_neurons(\n",
      "  File \"/juice/scr/wuzhengx/pyvene/tests/../pyvene/models/modeling_utils.py\", line 258, in gather_neurons\n",
      "    tensor_output = torch.gather(\n",
      " (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)\n",
      "  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n",
      "."
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Test Suite: SubspaceInterventionWithMLPTestCase ===\n",
      "loaded model\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "----------------------------------------------------------------------\n",
      "Ran 6 tests in 0.180s\n",
      "\n",
      "OK\n"
     ]
    }
   ],
   "source": [
    "class SubspaceInterventionWithMLPTestCase(unittest.TestCase):\n",
    "    @classmethod\n",
    "    def setUpClass(self):\n",
    "        print(\"=== Test Suite: SubspaceInterventionWithMLPTestCase ===\")\n",
    "        self.config, self.tokenizer, self.mlp = create_mlp_classifier(\n",
    "            MLPConfig(\n",
    "                h_dim=3, n_layer=1, pdrop=0.0, \n",
    "                include_bias=False, squeeze_output=False\n",
    "            )\n",
    "        )\n",
    "\n",
    "        self.test_subspace_intervention_link_intervenable_config = IntervenableConfig(\n",
    "            intervenable_model_type=type(self.mlp),\n",
    "            intervenable_representations=[\n",
    "                IntervenableRepresentationConfig(\n",
    "                    0,\n",
    "                    \"mlp_activation\",\n",
    "                    \"pos\",  # mlp layer creates a single token reprs\n",
    "                    1,\n",
    "                    subspace_partition=[\n",
    "                        [1, 3],\n",
    "                        [0, 1],\n",
    "                    ],  # partition into two sets of subspaces\n",
    "                    intervention_link_key=0,  # linked ones target the same subspace\n",
    "                ),\n",
    "                IntervenableRepresentationConfig(\n",
    "                    0,\n",
    "                    \"mlp_activation\",\n",
    "                    \"pos\",  # mlp layer creates a single token reprs\n",
    "                    1,\n",
    "                    subspace_partition=[\n",
    "                        [1, 3],\n",
    "                        [0, 1],\n",
    "                    ],  # partition into two sets of subspaces\n",
    "                    intervention_link_key=0,  # linked ones target the same subspace\n",
    "                ),\n",
    "            ],\n",
    "            intervenable_interventions_type=VanillaIntervention,\n",
    "        )\n",
    "\n",
    "        self.test_subspace_no_intervention_link_intervenable_config = (\n",
    "            IntervenableConfig(\n",
    "                intervenable_model_type=type(self.mlp),\n",
    "                intervenable_representations=[\n",
    "                    IntervenableRepresentationConfig(\n",
    "                        0,\n",
    "                        \"mlp_activation\",\n",
    "                        \"pos\",  # mlp layer creates a single token reprs\n",
    "                        1,\n",
    "                        subspace_partition=[\n",
    "                            [0, 1],\n",
    "                            [1, 3],\n",
    "                        ],  # partition into two sets of subspaces\n",
    "                    ),\n",
    "                    IntervenableRepresentationConfig(\n",
    "                        0,\n",
    "                        \"mlp_activation\",\n",
    "                        \"pos\",  # mlp layer creates a single token reprs\n",
    "                        1,\n",
    "                        subspace_partition=[\n",
    "                            [0, 1],\n",
    "                            [1, 3],\n",
    "                        ],  # partition into two sets of subspaces\n",
    "                    ),\n",
    "                ],\n",
    "                intervenable_interventions_type=VanillaIntervention,\n",
    "            )\n",
    "        )\n",
    "\n",
    "        self.test_subspace_no_intervention_link_trainable_intervenable_config = (\n",
    "            IntervenableConfig(\n",
    "                intervenable_model_type=type(self.mlp),\n",
    "                intervenable_representations=[\n",
    "                    IntervenableRepresentationConfig(\n",
    "                        0,\n",
    "                        \"mlp_activation\",\n",
    "                        \"pos\",  # mlp layer creates a single token reprs\n",
    "                        1,\n",
    "                        intervenable_low_rank_dimension=2,\n",
    "                        subspace_partition=[\n",
    "                            [0, 1],\n",
    "                            [1, 2],\n",
    "                        ],  # partition into two sets of subspaces\n",
    "                    ),\n",
    "                    IntervenableRepresentationConfig(\n",
    "                        0,\n",
    "                        \"mlp_activation\",\n",
    "                        \"pos\",  # mlp layer creates a single token reprs\n",
    "                        1,\n",
    "                        intervenable_low_rank_dimension=2,\n",
    "                        subspace_partition=[\n",
    "                            [0, 1],\n",
    "                            [1, 2],\n",
    "                        ],  # partition into two sets of subspaces\n",
    "                    ),\n",
    "                ],\n",
    "                intervenable_interventions_type=LowRankRotatedSpaceIntervention,\n",
    "            )\n",
    "        )\n",
    "\n",
    "    def test_clean_run_positive(self):\n",
    "        \"\"\"\n",
    "        Positive test case to check whether vanilla forward pass work\n",
    "        with our object.\n",
    "        \"\"\"\n",
    "        intervenable = IntervenableModel(\n",
    "            self.test_subspace_intervention_link_intervenable_config, self.mlp\n",
    "        )\n",
    "        base = {\"inputs_embeds\": torch.rand(10, 1, 3)}\n",
    "        self.assertTrue(\n",
    "            torch.allclose(ONE_MLP_CLEAN_RUN(base, self.mlp), intervenable(base)[0][0])\n",
    "        )\n",
    "\n",
    "    def test_with_subspace_positive(self):\n",
    "        \"\"\"\n",
    "        Positive test case to intervene only a set of subspace.\n",
    "        \"\"\"\n",
    "        intervenable = IntervenableModel(\n",
    "            self.test_subspace_intervention_link_intervenable_config, self.mlp\n",
    "        )\n",
    "        # golden label\n",
    "        b_s = 10\n",
    "        base = {\"inputs_embeds\": torch.rand(b_s, 1, 3)}\n",
    "        source_1 = {\"inputs_embeds\": torch.rand(b_s, 1, 3)}\n",
    "        source_2 = {\"inputs_embeds\": torch.rand(b_s, 1, 3)}\n",
    "        base_act = ONE_MLP_FETCH_W1_ACT(base, self.mlp)\n",
    "        source_1_act = ONE_MLP_FETCH_W1_ACT(source_1, self.mlp)\n",
    "        intervened_act = base_act.clone()  # relentless clone\n",
    "        intervened_act[..., 1:3] = source_1_act[..., 1:3]\n",
    "        golden_out = ONE_MLP_WITH_W1_ACT_RUN(intervened_act, self.mlp)\n",
    "\n",
    "        # our label\n",
    "        _, our_out = intervenable(\n",
    "            base,\n",
    "            [source_1, None],\n",
    "            {\"sources->base\": ([[[0]] * b_s, None], [[[0]] * b_s, None])},\n",
    "            subspaces=[[[0]] * b_s, None],\n",
    "        )\n",
    "        self.assertTrue(torch.allclose(golden_out, our_out[0]))\n",
    "\n",
    "    def test_with_subspace_negative(self):\n",
    "        \"\"\"\n",
    "        Negative test case to check input length.\n",
    "        \"\"\"\n",
    "        intervenable = IntervenableModel(\n",
    "            self.test_subspace_intervention_link_intervenable_config, self.mlp\n",
    "        )\n",
    "        # golden label\n",
    "        b_s = 10\n",
    "        base = {\"inputs_embeds\": torch.rand(b_s, 1, 3)}\n",
    "        source_1 = {\"inputs_embeds\": torch.rand(b_s, 1, 3)}\n",
    "        source_2 = {\"inputs_embeds\": torch.rand(b_s, 1, 3)}\n",
    "\n",
    "        try:\n",
    "            intervenable(\n",
    "                base,\n",
    "                [source_1],\n",
    "                {\"sources->base\": ([[[0]] * b_s], [[[0]] * b_s])},\n",
    "                subspaces=[[[0]] * b_s],\n",
    "            )\n",
    "        except ValueError:\n",
    "            pass\n",
    "        else:\n",
    "            raise AssertionError(\"ValueError was not raised\")\n",
    "\n",
    "    def test_intervention_link_positive(self):\n",
    "        \"\"\"\n",
    "        Positive test case to intervene linked subspace.\n",
    "        \"\"\"\n",
    "        intervenable = IntervenableModel(\n",
    "            self.test_subspace_intervention_link_intervenable_config, self.mlp\n",
    "        )\n",
    "        # golden label\n",
    "        b_s = 10\n",
    "        base = {\"inputs_embeds\": torch.rand(b_s, 1, 3)}\n",
    "        source_1 = {\"inputs_embeds\": torch.rand(b_s, 1, 3)}\n",
    "        source_2 = {\"inputs_embeds\": torch.rand(b_s, 1, 3)}\n",
    "        base_act = ONE_MLP_FETCH_W1_ACT(base, self.mlp)\n",
    "        source_1_act = ONE_MLP_FETCH_W1_ACT(source_1, self.mlp)\n",
    "        source_2_act = ONE_MLP_FETCH_W1_ACT(source_2, self.mlp)\n",
    "\n",
    "        # overwrite version\n",
    "        intervened_act = base_act.clone()  # relentless clone\n",
    "        intervened_act[..., 1:3] = source_2_act[..., 1:3]\n",
    "        golden_out_overwrite = ONE_MLP_WITH_W1_ACT_RUN(intervened_act, self.mlp)\n",
    "\n",
    "        # success version\n",
    "        intervened_act = base_act.clone()  # relentless clone\n",
    "        intervened_act[..., 1:3] = source_1_act[..., 1:3]\n",
    "        intervened_act[..., 0] = source_2_act[..., 0]\n",
    "        golden_out_success = ONE_MLP_WITH_W1_ACT_RUN(intervened_act, self.mlp)\n",
    "\n",
    "        # subcase where the second one accidentally overwrites the first one\n",
    "        _, our_out_overwrite = intervenable(\n",
    "            base,\n",
    "            [source_1, source_2],\n",
    "            {\"sources->base\": ([[[0]] * b_s, [[0]] * b_s], [[[0]] * b_s, [[0]] * b_s])},\n",
    "            subspaces=[[[0]] * b_s, [[0]] * b_s],\n",
    "        )\n",
    "\n",
    "        # success\n",
    "        _, our_out_success = intervenable(\n",
    "            base,\n",
    "            [source_1, source_2],\n",
    "            {\"sources->base\": ([[[0]] * b_s, [[0]] * b_s], [[[0]] * b_s, [[0]] * b_s])},\n",
    "            subspaces=[[[0]] * b_s, [[1]] * b_s],\n",
    "        )\n",
    "\n",
    "        self.assertTrue(torch.allclose(golden_out_overwrite, our_out_overwrite[0]))\n",
    "        self.assertTrue(torch.allclose(golden_out_success, our_out_success[0]))\n",
    "\n",
    "    def test_no_intervention_link_positive(self):\n",
    "        \"\"\"\n",
    "        Positive test case to intervene not linked subspace (overwrite).\n",
    "        \"\"\"\n",
    "        intervenable = IntervenableModel(\n",
    "            self.test_subspace_no_intervention_link_intervenable_config, self.mlp\n",
    "        )\n",
    "        # golden label\n",
    "        b_s = 10\n",
    "        base = {\"inputs_embeds\": torch.rand(b_s, 1, 3)}\n",
    "        source_1 = {\"inputs_embeds\": torch.rand(b_s, 1, 3)}\n",
    "        source_2 = {\"inputs_embeds\": torch.rand(b_s, 1, 3)}\n",
    "        base_act = ONE_MLP_FETCH_W1_ACT(base, self.mlp)\n",
    "        source_1_act = ONE_MLP_FETCH_W1_ACT(source_1, self.mlp)\n",
    "        source_2_act = ONE_MLP_FETCH_W1_ACT(source_2, self.mlp)\n",
    "\n",
    "        # inplace overwrite version\n",
    "        intervened_act = base_act.clone()  # relentless clone\n",
    "        intervened_act[..., 0] = source_2_act[..., 0]\n",
    "        golden_out_inplace = ONE_MLP_WITH_W1_ACT_RUN(intervened_act, self.mlp)\n",
    "\n",
    "        # overwrite version\n",
    "        intervened_act = base_act.clone()  # relentless clone\n",
    "        intervened_act[..., 0] = source_1_act[..., 0]\n",
    "        intervened_act[..., 1:3] = source_2_act[..., 1:3]\n",
    "        golden_out_overwrite = ONE_MLP_WITH_W1_ACT_RUN(intervened_act, self.mlp)\n",
    "\n",
    "        # subcase where the second one accidentally overwrites the first one\n",
    "        _, our_out_inplace = intervenable(\n",
    "            base,\n",
    "            [source_1, source_2],\n",
    "            {\"sources->base\": ([[[0]] * b_s, [[0]] * b_s], [[[0]] * b_s, [[0]] * b_s])},\n",
    "            subspaces=[[[0]] * b_s, [[0]] * b_s],\n",
    "        )\n",
    "\n",
    "        # overwrite\n",
    "        _, our_out_overwrite = intervenable(\n",
    "            base,\n",
    "            [source_1, source_2],\n",
    "            {\"sources->base\": ([[[0]] * b_s, [[0]] * b_s], [[[0]] * b_s, [[0]] * b_s])},\n",
    "            subspaces=[[[0]] * b_s, [[1]] * b_s],\n",
    "        )\n",
    "\n",
    "        self.assertTrue(torch.allclose(golden_out_inplace, our_out_inplace[0]))\n",
    "        # the following thing work but gradient will fail check negative test cases\n",
    "        self.assertTrue(torch.allclose(golden_out_overwrite, our_out_overwrite[0]))\n",
    "\n",
    "    def test_no_intervention_link_negative(self):\n",
    "        pass\n",
    "        \"\"\"\n",
    "        Negative test case to intervene not linked subspace with trainable interventions.\n",
    "        \"\"\"\n",
    "        intervenable = IntervenableModel(\n",
    "            self.test_subspace_no_intervention_link_trainable_intervenable_config,\n",
    "            self.mlp,\n",
    "        )\n",
    "        # golden label\n",
    "        b_s = 10\n",
    "        base = {\"inputs_embeds\": torch.rand(b_s, 1, 3)}\n",
    "        source_1 = {\"inputs_embeds\": torch.rand(b_s, 1, 3)}\n",
    "        source_2 = {\"inputs_embeds\": torch.rand(b_s, 1, 3)}\n",
    "        base_act = ONE_MLP_FETCH_W1_ACT(base, self.mlp)\n",
    "        source_1_act = ONE_MLP_FETCH_W1_ACT(source_1, self.mlp)\n",
    "        source_2_act = ONE_MLP_FETCH_W1_ACT(source_2, self.mlp)\n",
    "\n",
    "        # overwrite version\n",
    "        intervened_act = base_act.clone()  # relentless clone\n",
    "        intervened_act[..., 0] = source_1_act[..., 0]\n",
    "        intervened_act[..., 1] = source_2_act[..., 1]\n",
    "        golden_out_overwrite = ONE_MLP_WITH_W1_ACT_RUN(intervened_act, self.mlp)\n",
    "\n",
    "        # overwrite\n",
    "        _, our_out_overwrite = intervenable(\n",
    "            base,\n",
    "            [source_1, source_2],\n",
    "            {\"sources->base\": ([[[0]] * b_s, [[0]] * b_s], [[[0]] * b_s, [[0]] * b_s])},\n",
    "            subspaces=[[[0]] * b_s, [[1]] * b_s],\n",
    "        )\n",
    "\n",
    "        try:\n",
    "            our_out_overwrite[0].sum().backward()\n",
    "        except RuntimeError:\n",
    "            pass\n",
    "        else:\n",
    "            raise AssertionError(\"RuntimeError by torch was not raised\")\n",
    "\n",
    "\n",
    "def suite():\n",
    "    suite = unittest.TestSuite()\n",
    "    suite.addTest(SubspaceInterventionWithMLPTestCase(\"test_clean_run_positive\"))\n",
    "    suite.addTest(SubspaceInterventionWithMLPTestCase(\"test_with_subspace_positive\"))\n",
    "    suite.addTest(SubspaceInterventionWithMLPTestCase(\"test_with_subspace_negative\"))\n",
    "    suite.addTest(\n",
    "        SubspaceInterventionWithMLPTestCase(\"test_intervention_link_positive\")\n",
    "    )\n",
    "    suite.addTest(\n",
    "        SubspaceInterventionWithMLPTestCase(\"test_no_intervention_link_positive\")\n",
    "    )\n",
    "    suite.addTest(\n",
    "        SubspaceInterventionWithMLPTestCase(\"test_no_intervention_link_negative\")\n",
    "    )\n",
    "    return suite\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    runner = unittest.TextTestRunner()\n",
    "    runner.run(suite())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a6708f2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
